"""
T5 Tokenizer
^^^^^^^^^^^^^^^^^

"""


from textattack.models.tokenizers import AutoTokenizer


class T5Tokenizer(AutoTokenizer):
    """Uses the T5 tokenizer to convert an input for processing.

    For more information, please see the T5 paper, "Exploring the Limits of
    Transfer Learning with a Unified Text-to-Text Transformer".
    Appendix D contains information about the various tasks supported
    by T5.

    Supports the following modes:

    * summarization: summarize English text
    * english_to_german: translate English to German
    * english_to_french: translate English to French
    * english_to_romanian: translate English to Romanian

    Every input is prefixed with the task string for the selected mode
    before it is handed to the parent ``AutoTokenizer`` (loaded from the
    ``t5-base`` tokenizer path).
    """

    # Task prefix prepended to every input, keyed by ``mode``.
    _MODE_PREFIXES = {
        "english_to_german": "translate English to German: ",
        "english_to_french": "translate English to French: ",
        "english_to_romanian": "translate English to Romanian: ",
        "summarization": "summarize: ",
    }

    def __init__(self, mode="english_to_german", max_length=64):
        """Select the task prefix and initialize the underlying tokenizer.

        Args:
            mode (str): One of the modes listed in the class docstring.
            max_length (int): Forwarded to ``AutoTokenizer``.

        Raises:
            ValueError: If ``mode`` is not a supported mode.
        """
        try:
            self.tokenization_prefix = self._MODE_PREFIXES[mode]
        except (KeyError, TypeError):
            # TypeError covers unhashable ``mode`` values, which the original
            # equality-based chain also rejected with a ValueError.
            raise ValueError(f"Invalid t5 tokenizer mode {mode}.") from None

        super().__init__(tokenizer_path="t5-base", max_length=max_length)

    def _prepare_text(self, text):
        """Validate a single input and prepend the task prefix.

        Accepts either a ``str`` or a 1-tuple wrapping a ``str`` (the tuple
        is unwrapped first).

        Raises:
            ValueError: If ``text`` is a tuple whose length is not exactly 1.
            TypeError: If the (unwrapped) input is not a ``str``.
        """
        if isinstance(text, tuple):
            # T5 takes one text sequence. Reject empty tuples too, rather than
            # letting ``text[0]`` fail with an IndexError.
            if len(text) != 1:
                raise ValueError(
                    f"T5Tokenizer tuple inputs must have length 1; got {len(text)}"
                )
            (text,) = text
        if not isinstance(text, str):
            raise TypeError(f"T5Tokenizer expects `str` input, got {type(text)}")
        return self.tokenization_prefix + text

    def encode(self, text):
        """Encodes a string into IDs of tokens.

        This prepares an input to be passed into T5: the task prefix for the
        configured mode is prepended before delegating to
        ``AutoTokenizer.encode``.

        Args:
            text (str | tuple[str]): The input, optionally wrapped in a 1-tuple.

        Raises:
            ValueError: If ``text`` is a tuple whose length is not 1.
            TypeError: If the input is not a ``str``.
        """
        return super().encode(self._prepare_text(text))

    def batch_encode(self, input_text_list):
        """Encodes a list of inputs, prefixing each with the task string.

        Each element is validated exactly as in :meth:`encode` before the
        whole prefixed list is delegated to ``AutoTokenizer.batch_encode``.

        Args:
            input_text_list: Iterable of ``str`` or 1-tuples of ``str``.

        Raises:
            ValueError: If any tuple element does not have length 1.
            TypeError: If any (unwrapped) element is not a ``str``.
        """
        new_input_text_list = [self._prepare_text(text) for text in input_text_list]
        return super().batch_encode(new_input_text_list)

    def decode(self, ids):
        """Converts IDs (typically generated by the model) back to a string.

        Delegates directly to ``self.tokenizer.decode`` with its default
        arguments; no post-processing is done here.
        """
        return self.tokenizer.decode(ids)
